#!/usr/bin/env python
#
# Copyright 2014 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import argparse
from ast import literal_eval
from cmd import Cmd
import json
from multiprocessing import Process
import os
from textwrap import dedent

from websocket import create_connection

# Directory that holds the per-session output files. The script entry point
# creates it on startup when it does not already exist.
TMP_DIR = '/tmp/qdb/'
# %-format template for a session's output file path: a dot-prefixed file
# inside TMP_DIR whose name is filled in with the session uuid.
FILE_FORMAT = os.path.join(TMP_DIR, '.%s')

# Version string; not referenced elsewhere in the visible source.
_version = '0.1.0'


class QdbRepl(Cmd, object):
    """
    Interactive command-line client for a qdb debugging session.

    Each command typed at the prompt is turned into a JSON event and sent to
    the server over a websocket; any line that is not a known command is sent
    as an 'eval' event. Messages coming back from the server are not printed
    at the prompt: they are written to a file by a ServerPrinter that runs in
    a separate process.
    """
    def __init__(self,
                 wsaddress='ws://localhost:8002/{uuid}',
                 auth_msg='',
                 uuid='qdb',
                 prompt='(qdb) '):
        """
        Stores the connection settings; no connection is made until the
        command loop starts (see preloop).

        wsaddress -- websocket address format string, '{uuid}' is replaced
                     with the session uuid when connecting.
        auth_msg  -- payload sent with the 'start' event.
        uuid      -- the session uuid to connect to.
        prompt    -- the prompt displayed by the command loop.
        """
        self.auth_msg = auth_msg
        self.printer = None  # The ServerPrinter process, set by _connect.
        self.prompt = prompt
        self.wsaddress = wsaddress
        self.uuid = uuid
        self._ws = None  # The websocket connection, set by _connect.
        super(QdbRepl, self).__init__()

    def _connect(self):
        """
        Connects to the server and sends start.

        Also starts the ServerPrinter process that writes the server's
        messages to the session's output file, and records that file's path
        in self.filename.
        """
        final_addr = self.wsaddress.format(uuid=self.uuid)
        print('Connecting to "%s" with auth_msg "%s"'
              % (final_addr, self.auth_msg))
        self._ws = create_connection(final_addr)
        self.send_command('start', self.auth_msg)
        self.filename = FILE_FORMAT % self.uuid
        self.printer = Process(
            target=ServerPrinter,
            args=(self._ws, self.filename),
        )
        self.printer.start()
        print('qdb connected: ' + self.filename)

    def fmt_msg(self, event, payload):
        """
        Formats a message to send to the server.

        Returns a JSON object string with the event under 'e' and, unless the
        payload is None, the payload under 'p'.
        """
        if payload is None:
            return json.dumps({'e': event})
        return json.dumps({
            'e': event,
            'p': payload,
        })

    def send_command(self, event, payload=None):
        """
        Sends a formatted command.
        """
        self.send(self.fmt_msg(event, payload))

    def send(self, msg):
        """
        Sends a message to the server.
        """
        self._ws.send(msg)

    def preloop(self):
        """
        Connects to the server before interacting.
        """
        self._connect()

    def parse_break_arg(self, arg, temp=False):
        """
        Parses a breakpoint payload out of an argument.

        arg is the text of a Python dict literal. When temp is true, 'temp'
        is defaulted to True in the parsed dict unless the user gave it
        explicitly.

        Returns the parsed dict, or None (after printing an error) when the
        argument is empty, is not a valid literal or is not a dict.
        """
        if not arg:
            return None
        try:
            break_arg = literal_eval(arg)
        except (ValueError, SyntaxError, TypeError) as e:
            # TypeError covers literals such as a dict with an unhashable key.
            print('*** error: invalid breakpoint dict: %s' % e)
            return None
        if not isinstance(break_arg, dict):
            print('*** error: invalid breakpoint dict: expected a dict, got %s'
                  % type(break_arg).__name__)
            return None
        if temp:
            # An explicit 'temp' entry from the user wins over the default.
            break_arg.setdefault('temp', True)
        return break_arg

    def missing_argument(self, cmd):
        """
        Reports that cmd was invoked without its required argument(s).
        """
        print('*** error: %s: missing argument(s)' % cmd)

    def default(self, line):
        """
        eval if no command is given.
        """
        self.send_command('eval', line)

    def do_step(self, arg):
        """
        Sends the 'step' event; arg is ignored.
        """
        self.send_command('step')
    do_s = do_step

    def help_step(self):
        print(dedent(
            """\
            s(tep)
            Execute the next line, function call, or return.
            """
        ))
    help_s = help_step

    def do_return(self, arg):
        """
        Sends the 'return' event; arg is ignored.
        """
        self.send_command('return')
    do_r = do_return

    def help_return(self):
        print(dedent(
            """\
            r(eturn)
            Execute until the return event for the current stackframe.
            """
        ))
    help_r = help_return

    def do_next(self, arg):
        """
        Sends the 'next' event; arg is ignored.
        """
        self.send_command('next')
    do_n = do_next

    def help_next(self):
        print(dedent(
            """\
            n(ext)
            Execute up to the next line in the current frame.
            """
        ))
    help_n = help_next

    def do_until(self, arg):
        """
        Sends the 'until' event; arg is ignored.
        """
        self.send_command('until')
    do_unt = do_until

    def help_until(self):
        print(dedent(
            """\
            unt(il)
            Execute until the line greater than the current is hit or until
            you return from the current frame.
            """
        ))
    help_unt = help_until

    def do_continue(self, arg):
        """
        Sends the 'continue' event; arg is ignored.
        """
        self.send_command('continue')
    do_c = do_continue

    def help_continue(self):
        print(dedent(
            """\
            c(ontinue)
            Continue execution until the next breakpoint is hit. If there are
            no more breakpoints, stop tracing.
            """
        ))
    help_c = help_continue

    def do_watch(self, arg):
        """
        Sends a 'set_watch' event with the expression as a one item list.
        """
        if not arg:
            self.missing_argument('w(atch)')
            return
        self.send_command('set_watch', [arg])
    do_w = do_watch

    def help_watch(self):
        print(dedent(
            """\
            w(atch) EXPR
            Adds an expression to the watchlist.
            """
        ))
    help_w = help_watch

    def do_unwatch(self, arg):
        """
        Sends a 'clear_watch' event with the expression as a one item list.
        """
        if not arg:
            self.missing_argument('unw(atch)')
            return
        self.send_command('clear_watch', [arg])
    do_unw = do_unwatch

    def help_unwatch(self):
        print(dedent(
            """\
            unw(atch) EXPR
            Removes an expression from the watchlist if it is already being
            watched, otherwise does nothing.
            """
        ))
    help_unw = help_unwatch

    def do_break(self, arg, temp=False):
        """
        Sends a 'set_break' event with the parsed breakpoint dict.

        temp is forwarded to parse_break_arg so that tbreak can default the
        breakpoint to temporary. Nothing is sent when parsing fails or the
        dict is empty.
        """
        if not arg:
            self.missing_argument('b(reak)')
            return
        break_arg = self.parse_break_arg(arg, temp)
        if break_arg:
            self.send_command('set_break', break_arg)
    do_b = do_break

    def help_break(self):
        print(dedent(
            """\
            b(reak) BREAK-DICT
            Adds a breakpoint with the form:
            {'file': str, 'line': int, 'temp': bool, 'cond': str, 'func': str}
            """
        ))
    help_b = help_break

    def do_clear(self, arg):
        """
        Sends a 'clear_break' event with the parsed breakpoint dict.

        Nothing is sent when parsing fails or the dict is empty.
        """
        if not arg:
            self.missing_argument('cl(ear)')
            return
        break_arg = self.parse_break_arg(arg)
        if break_arg:
            self.send_command('clear_break', break_arg)
    do_cl = do_clear

    def help_clear(self):
        print(dedent(
            """\
            cl(ear) BREAK-DICT
            Clears a breakpoint with the form:
            {'file': str, 'line': int, 'temp': bool, 'cond': str, 'func': str}
            Only 'file' and 'line' are needed.
            """
        ))
    help_cl = help_clear

    def do_tbreak(self, arg):
        """
        Same as do_break, but with 'temp' defaulted to True.
        """
        self.do_break(arg, temp=True)

    def help_tbreak(self):
        print(dedent(
            """\
            tbreak BREAK-DICT
            Same as break, but with 'temp' defaulted to True.
            """
        ))

    def do_list(self, arg):
        """
        Sends a 'list' event for FILE with the optional START and END lines.

        A bound that is not given is sent as None. Nothing is sent (and an
        error is printed) when a bound is not an integer or when more than
        two bounds are given.
        """
        args = arg.split()
        if not args:
            self.missing_argument('l(ist)')
            return
        file_ = args[0]
        # The usage string shows the bounds separated by a comma, so accept
        # commas as well as whitespace between them.
        bounds = ' '.join(args[1:]).replace(',', ' ').split()
        if len(bounds) > 2:
            print('*** error: l(ist): too many arguments')
            return
        try:
            bounds = [int(bound) for bound in bounds]
        except ValueError:
            print('*** error: l(ist): START and END must be integers')
            return
        # Pad with None so START may be given without END.
        start, end = bounds + [None] * (2 - len(bounds))
        self.send_command('list', {
            'file': file_,
            'start': start,
            'end': end,
        })
    do_l = do_list

    def help_list(self):
        print(dedent(
            """\
            l(ist) FILE [START, [END]]
            Shows the content of a file where START is the first line to show
            and END is the last. This acts like a Python slice.
            """
        ))
    help_l = help_list

    def do_up(self, arg):
        """
        Sends the 'up' event; arg is ignored.
        """
        self.send_command('up')
    do_u = do_up

    def help_up(self):
        print(dedent(
            """\
            u(p)
            Steps up a stackframe if possible.
            """
        ))
    help_u = help_up

    def do_down(self, arg):
        """
        Sends the 'down' event; arg is ignored.
        """
        self.send_command('down')
    do_d = do_down

    def help_down(self):
        print(dedent(
            """\
            d(own)
            Steps down a stackframe if possible.
            """
        ))
    help_d = help_down

    def do_locals(self, arg):
        """
        Sends the 'locals' event; arg is ignored.
        """
        self.send_command('locals')

    def help_locals(self):
        print(dedent(
            """\
            locals
            Report back the current stackframe's locals.
            """
        ))

    def do_pause(self, arg):
        """
        Sends the 'pause' event; arg is ignored.
        """
        self.send_command('pause')
    do_p = do_pause

    def help_pause(self):
        print(dedent(
            """\
            p(ause)
            Pauses execution, stopping at the next allowed event.
            """
        ))
    help_p = help_pause

    def do_quit(self, arg):
        """
        Sends a 'disable' event with the given mode and exits the client.

        arg must be empty, 'soft' or 'hard'; empty means 'soft'. Any other
        value prints an error and leaves the session running.

        Raises SystemExit(0) once the event has been sent.
        """
        if not arg or arg == 'soft':
            self.send_command('disable', 'soft')
        elif arg == 'hard':
            self.send_command('disable', 'hard')
        else:
            print('*** error: disable: argument must be \'soft\' or \'hard\'')
            return
        # The printer only exists once _connect has started it.
        if self.printer is not None:
            self.printer.terminate()
        # Raised directly instead of calling the site-provided exit() helper,
        # which is not guaranteed to be installed.
        raise SystemExit(0)  # I think we're done here
    do_q = do_quit
    do_EOF = do_quit  # EOF soft kills.

    def help_quit(self):
        print(dedent(
            """\
            q(uit) [MODE]
            Stops the debugging session with the given mode, defaulting to
            'soft'.
            """
        ))
    help_q = help_quit
    help_EOF = help_quit

    def help_help(self):
        print(dedent(
            """\
            help [COMMAND]
            Without COMMAND, help prints all available commands.
            With COMMAND, help prints the help string for the given command.
            """
        ))


class ServerPrinter(object):
    """
    Manages reading from the server and pretty printing the outstream.

    Constructing an instance does all of the work: it opens the output file,
    then reads events from the socket and writes a formatted line (or lines)
    for each one until the event stream ends. An event is a dict with the
    event name under 'e' and an optional payload under 'p'; it is handled by
    the method named 'event_' + name.
    """
    def __init__(self, socket, filename):
        """
        socket   -- connection object whose recv() returns one JSON encoded
                    event per call.
        filename -- path of the file to write the formatted output to; it is
                    truncated if it already exists.
        """
        self.socket = socket
        self.filename = filename
        # Default buffering is used because unbuffered text files are not
        # supported by every Python version; writeln flushes after each
        # message so that the output still shows up in the file immediately.
        self._file = open(filename, 'w')
        try:
            self.writeln('Tracing...')
            for event in self.get_events():
                evfn = getattr(self, 'event_' + event['e'], None)
                if not evfn:
                    self.unknown_event(event['e'])
                else:
                    evfn(event.get('p'))
        finally:
            # Release the file whether the stream ended or a handler failed.
            self._file.close()

    def get_events(self):
        """
        Yields unpacked events from the socket.

        The stream ends, without raising, on the first failure to receive or
        decode a message.
        """
        while True:
            try:
                event = json.loads(self.socket.recv())
            except (Exception, KeyboardInterrupt):
                # KeyboardInterrupt is included so that an interrupt that
                # arrives while waiting on the socket ends the stream quietly
                # instead of producing a traceback.
                return
            yield event

    def writeln(self, msg):
        """
        Writes msg followed by a newline to the output file and flushes it.
        """
        print(msg, file=self._file)
        self._file.flush()

    def unknown_event(self, e):
        """
        Reports an event name that has no matching event_ method.
        """
        self.writeln('*** error: %s: unknown event type' % e)

    def event_print(self, payload):
        """
        Writes the input and output of an evaluation, flagging errors.

        Nothing is written when the output is empty.
        """
        out = payload['output']
        if out:
            self.writeln(
                '# %s%s: %s' % ('***ERROR: ' if payload['exc'] else '',
                                payload['input'],
                                out)
            )

    def event_list(self, payload):
        """
        Writes the payload of a 'list' event as is.
        """
        self.writeln(payload)

    def event_stack(self, payload):
        """
        Writes the file, line number and code of the current stackframe.
        """
        frame = payload['stack'][payload['index']]  # Current frame
        self.writeln('> %s:%d' % (frame['file'], frame['line']))
        self.writeln('-> ' + frame['code'])

    def event_watchlist(self, payload):
        """
        Writes one line per watched expression with its value.
        """
        self.writeln('watchlist: [')
        for watched in payload:
            self.writeln(
                '  > %s%s: %s'
                % ('***ERROR: ' if watched['exc'] else '',
                   watched['expr'], watched['value'])
            )
        self.writeln(']')

    def event_exception(self, payload):
        """
        Writes the type and value of an exception.
        """
        self.writeln('--* %s:%s' % (payload['type'], payload['value']))

    def event_breakpoints(self, payload):
        """
        Writes one line per breakpoint with all of its fields.
        """
        self.writeln('breakpoints: [')
        for breakpoint in payload:
            self.writeln('  > %s %d %s %s %s'
                         % (breakpoint['file'],
                            breakpoint['line'],
                            breakpoint['temp'],
                            breakpoint['cond'],
                            breakpoint['func']))
        self.writeln(']')

    def event_error(self, payload):
        """
        Writes the type and data of an error reported by the server.
        """
        self.writeln('*** error: %s: %s' % (payload['type'], payload['data']))

    def event_return(self, payload):
        """
        Writes the payload of a 'return' event.
        """
        self.writeln('--> returning with %s' % payload)

    def event_locals(self, payload):
        """
        Writes one 'name: value' line per entry of the payload mapping.
        """
        self.writeln('locals: [')
        # items() rather than iteritems() so this runs on Python 2 and 3.
        for p in payload.items():
            self.writeln('  > %s: %s' % p)
        self.writeln(']')

    def event_disable(self, payload):
        """
        Writes the farewell message; the payload is ignored.
        """
        self.writeln('Disabled, goodbye.')


if __name__ == '__main__':
    # Command-line entry point: parse the connection options, make sure the
    # output directory exists, run the REPL and finally remove the session's
    # output file.
    argparser = argparse.ArgumentParser()
    argparser.add_argument(
        '-w', '--ws-address',
        type=str,
        metavar='ADDR-FMT',
        # argparse help strings are %-formatted, so braces must not be
        # doubled here: '{uuid}' is exactly what the user has to type.
        help='The websocket address format string, containing {uuid}.',
        default='ws://localhost:8002/{uuid}',
    )
    argparser.add_argument(
        '-u', '--uuid',
        type=str,
        metavar='UUID',
        help='The session uuid that you wish to connect to.',
        default='qdb',
    )
    argparser.add_argument(
        '-a', '--auth-msg',
        type=str,
        metavar='AUTH-MSG',
        help='The authentication message to send with the start event.',
        default='',
    )
    args = argparser.parse_args()
    repl = QdbRepl(
        wsaddress=args.ws_address,
        uuid=args.uuid,
        auth_msg=args.auth_msg,
    )
    try:
        os.makedirs(TMP_DIR)
    except OSError:
        # Another process may have created the directory in the meantime;
        # only fail when it really is not there.
        if not os.path.isdir(TMP_DIR):
            raise
    try:
        repl.cmdloop()
    finally:
        try:
            os.remove(repl.filename)  # Cleans up the tmp output file.
        except (AttributeError, OSError):
            # AttributeError: the connection never got far enough to set
            # repl.filename. OSError: the file was never created or is
            # already gone.
            pass
